{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Watermark Analysis\n",
    "\n",
    "Notebook for performing analysis and visualization of the effects of watermarking schemes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic imports\n",
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from matplotlib import rc\n",
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "rc('text', usetex=True)\n",
    "\n",
    "import cmasher as cmr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_from_disk"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the processed dataset/frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose which processed analysis dataset to load: leave exactly one `save_name` uncommented.\n",
    "# Trailing notes such as 'in figure' are the author's remarks about earlier runs.\n",
    "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
    "# save_name = \"analysis_ds_1-21_greedy_redo\" \n",
    "# save_name = \"analysis_ds_1-21_greedy_redo_truncated\"\n",
    "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_truncated\" \n",
    "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_0-5_truncated\" # in figure (not 100% sure this is correct, check)\n",
    "\n",
    "# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n",
    "\n",
    "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
    "# save_name = \"analysis_ds_1-23_en_1-3\"\n",
    "save_name = \"analysis_ds_1-23_pile_1-3\"\n",
    "\n",
    "# Directory (relative to the working directory) that `load_from_disk` reads in the next cell\n",
    "save_dir = f\"input/{save_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the processed dataset stored at `save_dir` with `datasets.load_from_disk`\n",
    "raw_data = load_from_disk(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### convert to pandas df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to a pandas DataFrame; all filtering and analysis below works on `df`\n",
    "df = raw_data.to_pandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row count before any filtering, then the last few rows as a quick sanity check\n",
    "print(\"Orig number of rows: {}\".format(len(df)))\n",
    "df.tail(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the available columns (many are referenced by name in the cells below)\n",
    "df.columns"
   ]
  },
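  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"retokenization\" problem \n",
    "\n",
    "Current hypothesis: rows matching this criterion (hard blacklist, measurable, yet a whitelist fraction below 1.0) stem from tokenization not being 1-to-1, i.e. re-tokenizing the decoded text does not always reproduce the generated tokens."
   ]
  },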
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-blacklist rows are expected to be 100% whitelist; count the measurable ones (fraction != -1.0) that are not\n",
    "is_hard = df['bl_type'] == 'hard'\n",
    "w_bl_wl_frac = df['w_bl_whitelist_fraction']\n",
    "retok_problematic_rows = df[is_hard & (w_bl_wl_frac != -1.0) & (w_bl_wl_frac != 1.0)]\n",
    "print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {is_hard.sum()}\")\n",
    "# retok_problematic_rows"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Replace or drop the specially marked -1 rows, since these are unmeasurable due to their short length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_len = len(df)\n",
    "\n",
    "# Alternative: mask the -1.0 sentinels to NA instead of dropping the rows\n",
    "# df['no_bl_whitelist_fraction'].mask(df['no_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
    "# df['w_bl_whitelist_fraction'].mask(df['w_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
    "\n",
    "# -1.0 marks an output too short to measure; keep rows measurable for both outputs\n",
    "df = df[\n",
    "    (df[\"no_bl_whitelist_fraction\"] != -1.0)\n",
    "    & (df[\"w_bl_whitelist_fraction\"] != -1.0)\n",
    "]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Drop rows where there weren't enough tokens to measure ppl in one or both of the output cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_len = len(df)\n",
    "# perplexity is missing when an output had too few tokens; require it for both outputs\n",
    "df = df.dropna(subset=[\"no_bl_ppl\", \"w_bl_ppl\"])\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Drop rows with a really large bias (> 100.0), as 100.0 is already $\\simeq \\infty$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_len = len(df)\n",
    "\n",
    "# biases above 100.0 are already effectively infinite; rows with a NaN bias fail the test and are dropped too\n",
    "df = df.loc[df[\"bl_logit_bias\"].le(100.0)]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Drop rows that use sampling together with beam search (not considered at this time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_len = len(df)\n",
    "\n",
    "# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n",
    "# keep greedy/beam-search rows, and sampling rows only without beam search (num_beams == 1)\n",
    "df = df[\n",
    "    (df[\"use_sampling\"].eq(True) & df[\"num_beams\"].eq(1))\n",
    "    | df[\"use_sampling\"].eq(False)\n",
    "]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Correct the sampling temperature column (fill missing values: 0.0 for greedy/beam search, 1.0 for sampling)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill missing sampling temperatures: 0.0 where sampling is off, 1.0 where it is on\n",
    "for sampling_flag, fill_temp in ((False, 0.0), (True, 1.0)):\n",
    "    rows = df[\"use_sampling\"] == sampling_flag\n",
    "    df.loc[rows, \"sampling_temp\"] = df.loc[rows, \"sampling_temp\"].fillna(fill_temp)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### marking the hard blacklist rows as having inf/very large bias\n",
    "\n",
    "(after the > 100.0 bias drop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Represent hard blacklisting as an infinite logit bias (copied to `delta` in the rename cell)\n",
    "df.loc[df[\"bl_type\"].eq(\"hard\"), \"bl_logit_bias\"] = np.inf\n",
    "# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Rename some parameters and add derived columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short aliases used throughout the analysis: delta = bl_logit_bias, gamma = 1 - bl_proportion (rounded to 3 decimals)\n",
    "df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
    "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
    "df[\"gamma\"] = df[\"gamma\"].round(3)\n",
    "\n",
    "# Absolute whitelist-token counts recovered from the stored fractions\n",
    "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "\n",
    "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
    "\n",
    "# Test for the column itself: a bare string literal is always truthy, which made this branch\n",
    "# unconditional and raised a KeyError on datasets without the column.\n",
    "# NOTE: the z-score cell reads `baseline_num_tokens_generated`, so it still requires this column.\n",
    "if \"real_completion_length\" in df.columns:\n",
    "    df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
    "\n",
    "if \"actual_attacked_ratio\" in df.columns:\n",
    "    df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
    "\n",
    "if \"meta\" in df.columns:\n",
    "    df[\"pile_set_name\"] = df[\"meta\"].apply(lambda meta: meta[\"pile_set_name\"])\n",
    "\n",
    "# Hit-list length of each output (used by the length filter further down)\n",
    "for prefix in (\"baseline\", \"no_bl\", \"w_bl\"):\n",
    "    df[f\"{prefix}_hit_list_length\"] = df[f\"{prefix}_hit_list\"].apply(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for pile outlier filtering: per-output space counts, and spaces per hit-list entry\n",
    "space_text_cols = {\"w_bl\": \"w_bl_output\", \"no_bl\": \"no_bl_output\", \"baseline\": \"baseline_completion\"}\n",
    "for prefix, text_col in space_text_cols.items():\n",
    "    df[f\"{prefix}_space_count\"] = df[text_col].apply(lambda text: text.count(\" \"))\n",
    "for prefix in space_text_cols:\n",
    "    df[f\"{prefix}_space_frac\"] = df[f\"{prefix}_space_count\"].values / df[f\"{prefix}_hit_list_length\"]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Filter for the generation lengths we want to look at"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "orig_len = len(df)\n",
    "\n",
    "# Alternative filters used for other runs (kept for reference):\n",
    "# # main filters\n",
    "# # df = df[(df[\"real_completion_length\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
    "# df = df[(df[\"gamma\"] == 0.1) | (df[\"gamma\"] == 0.25) | (df[\"gamma\"] == 0.5)]\n",
    "# df = df[(df[\"delta\"] == 1.0) | (df[\"delta\"] == 2.0) | (df[\"delta\"] == 10.0)]\n",
    "# df = df[(df[\"use_sampling\"] == True)]\n",
    "# df = df[(df[\"bl_type\"] == \"soft\")]\n",
    "\n",
    "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)] # now also applies to the truncated version\n",
    "# df = df[(df[\"no_bl_num_tokens_generated\"] >= 500) & (df[\"w_bl_num_tokens_generated\"] >= 500)] # all gas noop\n",
    "\n",
    "# # # attack specific\n",
    "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
    "# df = df[(df[\"replace_ratio\"] <= 0.7)]\n",
    "\n",
    "# NOTE pile only: drop outputs whose space fraction exceeds 0.9\n",
    "df = df[df[\"w_bl_space_frac\"].le(0.9) & df[\"no_bl_space_frac\"].le(0.9)]\n",
    "# df = df[df[\"pile_set_name\"] != \"Github\"]\n",
    "\n",
    "# keep rows whose hit-list length lies in [lower_T, upper_T] for every output (now also applies to the truncated version)\n",
    "upper_T = 205\n",
    "lower_T = 195\n",
    "for prefix in (\"baseline\", \"no_bl\", \"w_bl\"):\n",
    "    df = df[df[f\"{prefix}_hit_list_length\"].between(lower_T, upper_T)]\n",
    "\n",
    "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Add z-scores (convert the raw watermark measurement, the whitelist fraction, to a z-score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt  # no longer used below; kept in case other cells rely on the name\n",
    "import scipy.stats\n",
    "\n",
    "def compute_z_score(observed_wl_frac, T, gamma):\n",
    "    \"\"\"z-score of an observed whitelist fraction over T tokens.\n",
    "\n",
    "    The denominator is the standard deviation of a Binomial(T, gamma) proportion, i.e. the null\n",
    "    model treats each token as whitelisted independently with probability gamma.\n",
    "    Works on scalars and, elementwise, on numpy arrays / pandas Series.\n",
    "    \"\"\"\n",
    "    numer = observed_wl_frac - gamma\n",
    "    denom = np.sqrt(gamma*(1-gamma)/T)\n",
    "    z = numer/denom\n",
    "    return z\n",
    "\n",
    "def compute_wl_for_z(z, T, gamma):\n",
    "    \"\"\"Inverse of `compute_z_score` scaled by T: the whitelist-token count that yields z.\"\"\"\n",
    "    denom = np.sqrt(gamma*(1-gamma)/T)\n",
    "    numer = ((z*denom)+gamma)*T\n",
    "    return numer\n",
    "\n",
    "def compute_p_value(z):\n",
    "    \"\"\"One-sided (upper-tail) p-value P(Z >= z) for a standard normal Z.\n",
    "\n",
    "    Detection only rejects H0 for large positive z (`z > cutoff`), so a strongly negative z\n",
    "    must give a large p-value; `sf(abs(z))` would wrongly report it as highly significant.\n",
    "    \"\"\"\n",
    "    p_value = scipy.stats.norm.sf(z)\n",
    "    return p_value\n",
    "\n",
    "# Vectorised over whole columns: same values as a row-wise `.apply`, and it also copes with an empty frame\n",
    "z_score_prefixes = [\"baseline\", \"no_bl\", \"w_bl\"]\n",
    "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
    "    z_score_prefixes.append(\"w_bl_attacked\")\n",
    "\n",
    "for prefix in z_score_prefixes:\n",
    "    df[f\"{prefix}_z_score\"] = compute_z_score(\n",
    "        df[f\"{prefix}_whitelist_fraction\"], df[f\"{prefix}_num_tokens_generated\"], df[\"gamma\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra columns for the attacked outputs (only present in attack datasets).\n",
    "# `w_bl_attacked_z_score` is already computed by the z-score cell, so it is not recomputed here.\n",
    "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
    "    from IPython.display import display\n",
    "\n",
    "    df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
    "\n",
    "    # an expression nested in an `if` is not echoed by Jupyter, so display it explicitly\n",
    "    display(df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare groupby (decide which hyperparameters to group the rows by)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier groupings, kept for reference:\n",
    "# groupby_fields = ['num_beams', 'max_new_tokens']\n",
    "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens']\n",
    "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_logit_bias']\n",
    "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
    "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
    "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias','bl_proportion']\n",
    "# groupby_fields = ['use_sampling','num_beams','bl_type','bl_logit_bias','bl_proportion']\n",
    "\n",
    "# Field order defines the order of the group-key tuples used by the figure cells below.\n",
    "# Attack datasets additionally group by the attack budget `replace_ratio`.\n",
    "is_attack_data = \"w_bl_attacked_whitelist_fraction\" in df.columns\n",
    "groupby_fields = (\n",
    "    ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n",
    "    if is_attack_data\n",
    "    else ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
    ")\n",
    "# groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
    "# groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Narrowing in on the interquartile range (IQR) (not generally used)\n",
    "\n",
    "(removing outliers by subsetting to rows within the IQR of `avg_spike_entropy`, or near its mean/median)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tmp_grped_25 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.25).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_25th'})\n",
    "# tmp_grped_50 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.5).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_50th'})\n",
    "# tmp_grped_75 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.75).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_75th'})\n",
    "# df = df.merge(tmp_grped_25, on = groupby_fields)\n",
    "# df = df.merge(tmp_grped_50, on = groupby_fields)\n",
    "# df = df.merge(tmp_grped_75, on = groupby_fields)\n",
    "\n",
    "# # tmp_grped_mean = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].mean().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_mean'})\n",
    "# # tmp_grped_median = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].median().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_median'})\n",
    "# # df = df.merge(tmp_grped_mean, on = groupby_fields)\n",
    "# # df = df.merge(tmp_grped_median, on = groupby_fields)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # eps = 0.001\n",
    "# eps = 0.005\n",
    "# df[\"avg_spike_entropy_mean_minus_eps\"] = df['avg_spike_entropy_mean']-eps\n",
    "# df[\"avg_spike_entropy_mean_plus_eps\"] = df['avg_spike_entropy_mean']+eps\n",
    "\n",
    "# df[\"avg_spike_entropy_median_minus_eps\"] = df['avg_spike_entropy_median']-eps\n",
    "# df[\"avg_spike_entropy_median_plus_eps\"] = df['avg_spike_entropy_median']+eps\n",
    "# print(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # df[[\"avg_spike_entropy_25th\",\"avg_spike_entropy_75th\"]]\n",
    "# df[[\"avg_spike_entropy_mean_minus_eps\",\"avg_spike_entropy_mean\",\"avg_spike_entropy_mean_plus_eps\"]]\n",
    "# df[[\"avg_spike_entropy_median_minus_eps\",\"avg_spike_entropy_median\",\"avg_spike_entropy_median_plus_eps\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# orig_len = len(df)\n",
    "\n",
    "# subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_25th\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_75th\"])]\n",
    "\n",
    "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
    "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
    "\n",
    "# print(f\"Dropped {orig_len-len(subdf)} rows, new len {len(subdf)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# subdf.groupby(groupby_fields)['avg_spike_entropy'].describe()\n",
    "# df.groupby(groupby_fields)['avg_spike_entropy'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# df = subdf"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform the groupby (group rows by their hyperparameter settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group rows by hyperparameter setting; with several fields the group keys are tuples\n",
    "# ordered like `groupby_fields` (the figure cells look groups up with such tuples)\n",
    "grouped_df = df.groupby(groupby_fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary after all filtering: rows kept and number of hyperparameter groups\n",
    "print(f\"Number of rows after filtering: {len(df)}\", f\"Number of groups: {len(grouped_df)}\", sep=\"\\n\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loop to compute \"confusion matrix\" (TPR,FPR etc.) at some z scores for tabulation (Table 2 & 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "\n",
    "def reject_null_hypo(z_score=None,cuttoff=None):\n",
    "    \"\"\"One-sided z-test decision: True where `z_score` > `cuttoff` (elementwise for arrays/Series).\"\"\"\n",
    "    return z_score > cuttoff\n",
    "\n",
    "def _rejection_rates(z_scores, cutoff):\n",
    "    \"\"\"Return (fraction of z_scores rejecting H0, fraction not rejecting H0) at `cutoff`.\"\"\"\n",
    "    rejected = reject_null_hypo(z_score=np.asarray(z_scores), cuttoff=cutoff)\n",
    "    return rejected.mean(), (~rejected).mean()\n",
    "\n",
    "# z-score thresholds at which the rates are tabulated\n",
    "z_threshes = [4.0, 5.0]\n",
    "\n",
    "# (z-score column prefix, rate name when H0 is rejected, rate name when it is not).\n",
    "# Unwatermarked outputs (baseline, no_bl) give FPR/TNR; watermarked outputs give TPR/FNR.\n",
    "rate_specs = [\n",
    "    (\"baseline\", \"fpr\", \"tnr\"),\n",
    "    (\"no_bl\", \"fpr\", \"tnr\"),\n",
    "    (\"w_bl\", \"tpr\", \"fnr\"),\n",
    "    (\"w_bl_attacked\", \"tpr\", \"fnr\"),  # only present for attack datasets\n",
    "]\n",
    "\n",
    "records = []\n",
    "\n",
    "for group_params in tqdm(list(grouped_df.groups.keys())):\n",
    "    sub_df = grouped_df.get_group(group_params)\n",
    "    grp_size = len(sub_df)\n",
    "\n",
    "    record = dict(zip(groupby_fields, group_params))\n",
    "    record[\"count\"] = grp_size\n",
    "\n",
    "    has_attack = \"w_bl_attacked_z_score\" in sub_df.columns\n",
    "    specs = rate_specs if has_attack else rate_specs[:-1]\n",
    "\n",
    "    for thresh in z_threshes:\n",
    "        for prefix, reject_rate, accept_rate in specs:\n",
    "            rej, acc = _rejection_rates(sub_df[f\"{prefix}_z_score\"].values, thresh)\n",
    "            record[f\"{prefix}_{reject_rate}_at_{thresh}\"] = rej\n",
    "            record[f\"{prefix}_{accept_rate}_at_{thresh}\"] = acc\n",
    "\n",
    "    records.append(record)\n",
    "\n",
    "# ROC curves / AUC per group are computed in the figure cells below.\n",
    "roc_df = pd.DataFrame.from_records(records)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detection thresholds (z-scores) to show in the table\n",
    "std_threshes = [4.0, 5.0] #, 6.0]\n",
    "# std_threshes = [4.0]\n",
    "\n",
    "# Leading (hyperparameter) columns of the table; alternatives kept for reference\n",
    "# columns = [\"num_beams\", \"delta\", \"gamma\", \"count\"]\n",
    "# columns = [\"delta\", \"gamma\", \"count\"]\n",
    "columns = [\"use_sampling\",\"delta\", \"gamma\", \"count\"]\n",
    "# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n",
    "\n",
    "for thresh in std_threshes:\n",
    "    if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n",
    "        # attack data: detection of the watermarked text before vs. after the attack\n",
    "        rate_cols = [\"w_bl_tpr\", \"w_bl_fnr\", \"w_bl_attacked_tpr\", \"w_bl_attacked_fnr\"]\n",
    "    else:\n",
    "        # clean data: error rates on baseline text and detection of watermarked text\n",
    "        rate_cols = [\"baseline_fpr\", \"baseline_tnr\", \"w_bl_tpr\", \"w_bl_fnr\"]\n",
    "    columns.extend(f\"{col}_at_{thresh}\" for col in rate_cols)\n",
    "\n",
    "# filter or not (alternative row subsets kept for reference)\n",
    "sub_df = roc_df[\n",
    "    roc_df[\"use_sampling\"].eq(True)\n",
    "    & roc_df[\"delta\"].isin([1.0, 2.0, 5.0])\n",
    "    & roc_df[\"gamma\"].isin([0.25, 0.5])\n",
    "]\n",
    "# sub_df = roc_df[(roc_df[\"use_sampling\"] == False) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0))  &  ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) ) & (roc_df[\"num_beams\"] == 8)]\n",
    "# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5)  | (roc_df[\"replace_ratio\"] == 0.7)]\n",
    "# sub_df = roc_df[(roc_df[\"num_beams\"] == 8)]\n",
    "# sub_df = roc_df\n",
    "\n",
    "# sub_df.sort_values(\"delta\")[columns]\n",
    "# sub_df.sort_values(\"num_beams\")[columns]\n",
    "sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### write tables to latex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n",
    "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n",
    "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n",
    "\n",
    "# print(sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns].round(3).to_latex(index=False))\n",
    "# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ROC: No Attack (figure 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "\n",
    "zoom = False\n",
    "# zoom = True\n",
    "\n",
    "# None -> multinomial-sampling groups (use_sampling=True, num_beams=1);\n",
    "# an int -> greedy/beam-search groups (use_sampling=False) with that many beams\n",
    "beam_search = None\n",
    "# beam_search = 1\n",
    "# beam_search = 4\n",
    "# beam_search = 8\n",
    "\n",
    "deltas = [1.0,2.0,5.0,10.0]\n",
    "# gammas = [0.25, 0.5]\n",
    "gammas = [0.25]\n",
    "# gammas = [0.5]\n",
    "\n",
    "# deltas = [1.0,2.0,5.0,10.0]\n",
    "# gammas = [0.1,0.5]\n",
    "\n",
    "# Group keys follow the regular `groupby_fields` order: (use_sampling, num_beams, delta, gamma)\n",
    "groups = []\n",
    "names = []\n",
    "for d in deltas:\n",
    "    for g in gammas:\n",
    "        if beam_search:\n",
    "            groups.append((False, beam_search, d, g))\n",
    "        else:\n",
    "            groups.append((True, 1, d, g))\n",
    "        # raw f-string so the LaTeX backslashes are not treated as (invalid) escape sequences\n",
    "        names.append(rf\"$\\delta:{d},\\gamma:{g}$\")\n",
    "groups=groups[::-1]\n",
    "names=names[::-1]\n",
    "\n",
    "# One colour per curve from a resampled viridis (one extra entry is dropped), reversed\n",
    "viridis = plt.colormaps['viridis'].resampled(len(groups)+1)\n",
    "cmap = viridis.colors[:len(groups)][::-1]\n",
    "\n",
    "# Per-group perplexity summary, computed once rather than once per curve\n",
    "w_bl_ppl_stats = grouped_df[\"w_bl_ppl\"].describe()\n",
    "\n",
    "# A single explicit figure: clf() + two plt.figure() calls used to leave stray empty\n",
    "# figures and drop the constrained layout.\n",
    "fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)\n",
    "\n",
    "# ROC of baseline (label 0) vs. watermarked (label 1) z-scores, one curve per setting\n",
    "for color, group, name in zip(cmap, groups, names):\n",
    "    group_df = grouped_df.get_group(group)\n",
    "    baseline_z_scores = group_df[\"baseline_z_score\"].values\n",
    "    w_bl_z_scores = group_df[\"w_bl_z_score\"].values\n",
    "    all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
    "    all_labels = np.concatenate([np.zeros_like(baseline_z_scores), np.ones_like(w_bl_z_scores)])\n",
    "\n",
    "    fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
    "    roc_auc = metrics.auc(fpr, tpr)\n",
    "\n",
    "    ppl_mean = round(w_bl_ppl_stats.loc[group, \"mean\"], 1)\n",
    "    ax.plot(fpr, tpr, color=color, label=f'{name}, AUC:{roc_auc:0.3f}, PPL:{ppl_mean}', linewidth=3)\n",
    "\n",
    "if \"w_bl_attacked_ppl\" not in df.columns:\n",
    "    # off-axes white marker: only contributes the unwatermarked (delta=0) mean PPL to the legend\n",
    "    no_bl_ppl_mean = round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)\n",
    "    ax.scatter([-1],[-1],label=rf'            $\\delta=0$, PPL: {no_bl_ppl_mean}', color=\"white\")\n",
    "\n",
    "# NOTE(review): the legend is only drawn when zoomed for non-attack data (and only when not\n",
    "# zoomed for attack data); behaviour kept as-is - confirm this is intended.\n",
    "if zoom:\n",
    "    if \"w_bl_attacked_ppl\" not in df.columns:\n",
    "        ax.legend(loc = 'lower right', fontsize = 12)\n",
    "    ax.set_xscale(\"log\")\n",
    "    # ax.set_yscale(\"log\")\n",
    "    # only the right limit: a left limit of 0 is invalid on a log axis and was ignored with a warning\n",
    "    ax.set_xlim(right=1)\n",
    "    ax.set_ylim([0.5, 1])\n",
    "    plot_name = (\"roc_auc_zoom\" if not beam_search else f\"roc_auc_zoom_greedy_beams_{beam_search}\")\n",
    "else:\n",
    "    if \"w_bl_attacked_ppl\" in df.columns:\n",
    "        ax.legend(loc = 'lower right', fontsize = 12)\n",
    "    ax.plot([0, 1], [0, 1],'r--')\n",
    "    ax.set_xlim([0, 1])\n",
    "    ax.set_ylim([0, 1])\n",
    "    plot_name = (\"roc_auc\" if not beam_search else f\"roc_auc_greedy_beams_{beam_search}\")\n",
    "\n",
    "ax.set_ylabel('True Positive Rate', fontsize = 12)\n",
    "ax.set_xlabel('False Positive Rate', fontsize = 12)\n",
    "\n",
    "print(plot_name)\n",
    "\n",
    "# fname = f\"figs/{plot_name}.pdf\"\n",
    "# fig.savefig(fname, format=\"pdf\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# ROC: Attack (figure 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics as metrics\n",
    "\n",
    "plt.clf()\n",
    "plt.figure(constrained_layout=True)\n",
    "plt.figure(figsize=(5, 4))\n",
    "\n",
    "# attack_budgets = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]\n",
    "attack_budgets = [0.1,0.3,0.5,0.7]\n",
    "groups = [(True, 1, 0.5, 2.0, budget) for budget in attack_budgets]\n",
    "beams = False\n",
    "# groups = [(False, 8, 0.5, 2.0, budget) for budget in attack_budgets]\n",
    "# beams = True\n",
    "\n",
    "names = [f\"$\\epsilon={eps}$\" for eps in attack_budgets]\n",
    "\n",
    "# Make colormap\n",
    "import matplotlib.pyplot as plt\n",
    "viridis = plt.colormaps['viridis'].resampled(len(groups)+1+1) # attack\n",
    "cmap = viridis.colors[:len(groups)+1][::-1]\n",
    "\n",
    "# plot original\n",
    "group = groups[0] # any will do\n",
    "baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
    "baseline_labels = np.zeros_like(baseline_z_scores)\n",
    "\n",
    "orig_watermark_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n",
    "watermark_labels = np.ones_like(orig_watermark_z_scores)\n",
    "\n",
    "all_scores = np.concatenate([baseline_z_scores,orig_watermark_z_scores])\n",
    "all_labels = np.concatenate([baseline_labels,watermark_labels])\n",
    "\n",
    "fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
    "roc_auc = metrics.auc(fpr, tpr)\n",
    "\n",
    "plt.plot(fpr, tpr, color=cmap[0], label = f'unattacked, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
    "\n",
    "# plot different attack levels\n",
    "for i,(group,name) in enumerate(zip(groups,names)):\n",
    "\n",
    "    baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
    "    attacked_z_scores = grouped_df.get_group(group)[\"w_bl_attacked_z_score\"].values\n",
    "    all_scores = np.concatenate([baseline_z_scores,attacked_z_scores])\n",
    "\n",
    "    baseline_labels = np.zeros_like(baseline_z_scores)\n",
    "    attacked_labels = np.ones_like(attacked_z_scores)\n",
    "    all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
    "\n",
    "    fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
    "    roc_auc = metrics.auc(fpr, tpr)\n",
    "\n",
    "    plt.plot(fpr, tpr, color=cmap[i+1], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_attacked_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
    "\n",
    "if \"w_bl_attacked_ppl\" in df.columns:\n",
    "    pass\n",
    "else:\n",
    "    # # vanilla ppl value\n",
    "    plt.scatter([-1],[-1],label=f'            $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n",
    "\n",
    "zoom = False\n",
    "# zoom = True\n",
    "if zoom:\n",
    "    if not \"w_bl_attacked_ppl\" in df.columns:\n",
    "        plt.legend(loc = 'lower right')\n",
    "    plt.xscale(\"log\")\n",
    "    # plt.yscale(\"log\")\n",
    "    plt.xlim([0, 1])\n",
    "    plt.ylim([0.5, 1])\n",
    "    if \"w_bl_attacked_ppl\" in df.columns:\n",
    "        plot_name = \"roc_auc_untargeted_attack_no_beams_zoom\"\n",
    "        # plot_name = \"roc_auc_untargeted_attack_with_beams_zoom\"\n",
    "    else:\n",
    "        plot_name = \"roc_auc_zoom\"\n",
    "else:\n",
    "    if \"w_bl_attacked_ppl\" in df.columns:\n",
    "        plt.legend(loc = 'lower right',fontsize = 9)\n",
    "    plt.plot([0, 1], [0, 1],'r--')\n",
    "    plt.xlim([0, 1])\n",
    "    plt.ylim([0, 1])\n",
    "    if \"w_bl_attacked_ppl\" in df.columns:\n",
    "        if beams: plot_name = \"roc_auc_untargeted_attack_w_beams\"\n",
    "        if not beams: plot_name = \"roc_auc_untargeted_attack_no_beams\"\n",
    "    else:\n",
    "        plot_name = \"roc_auc\"\n",
    "\n",
    "plt.ylabel('True Positive Rate')\n",
    "plt.xlabel('False Positive Rate')\n",
    "\n",
    "print(plot_name)\n",
    "\n",
    "# fname = f\"figs/{plot_name}.pdf\"\n",
    "# plt.savefig(fname, format=\"pdf\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Z vs T (figure 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.clf()\n",
    "plt.figure(constrained_layout=True)\n",
    "plt.figure(figsize=(5, 4))\n",
    "\n",
    "# save_fig = True\n",
    "save_fig = False\n",
    "\n",
    "z_scores = True\n",
    "# z_scores = False\n",
    "\n",
    "beam_search = None\n",
    "# beam_search = 1\n",
    "# beam_search = 4\n",
    "# beam_search = 8\n",
    "\n",
    "ablate = \"delta\"\n",
    "delta_gammas = [\n",
    "    # (0.5,0.25),\n",
    "    # (1.0,0.25),\n",
    "    # (2.0,0.25),\n",
    "    # (5.0,0.25),\n",
    "    # (10.0,0.25),\n",
    "    (0.5,0.5),\n",
    "    (1.0,0.5),\n",
    "    (2.0,0.5),\n",
    "    (5.0,0.5),\n",
    "    (10.0,0.5),\n",
    "]\n",
    "# ablate = \"gamma\"\n",
    "# delta_gammas = [\n",
    "#     # (5.0,0.9),\n",
    "#     # (5.0,0.75),\n",
    "#     # (5.0,0.5),\n",
    "#     # (5.0,0.25),\n",
    "#     # (5.0,0.1),\n",
    "#     (2.0,0.9),\n",
    "#     (2.0,0.75),\n",
    "#     (2.0,0.5),\n",
    "#     (2.0,0.25),\n",
    "#     (2.0,0.1),\n",
    "# ]\n",
    "# if not z_scores: delta_gammas = delta_gammas[::-1]\n",
    "\n",
    "groups = []\n",
    "names = []\n",
    "\n",
    "for d,g in delta_gammas:\n",
    "        if beam_search:\n",
    "            groups.append((False, beam_search, d, g))\n",
    "        else:\n",
    "            groups.append((True, 1, d, g))\n",
    "        names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n",
    "\n",
    "groups=groups[::-1]\n",
    "names=names[::-1]\n",
    "\n",
    "\n",
    "axis_max_t = 200\n",
    "\n",
    "max_t = None\n",
    "# max_t = 200\n",
    "# max_t = 100\n",
    "# max_t = 50\n",
    "\n",
    "# Make colormap\n",
    "import matplotlib.pyplot as plt\n",
    "viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n",
    "cmap = viridis.colors[:len(groups)][::-1]\n",
    "\n",
    "for grp_idx,(group, name) in enumerate(zip(groups, names)):\n",
    "\n",
    "    delta, gamma = group[-2],group[-1]\n",
    "\n",
    "    # this is the series of bools corresponding to token at T being in whitelist\n",
    "    w_bl_hit_list = grouped_df.get_group(group)[\"w_bl_hit_list\"].to_list()\n",
    "\n",
    "    lengths = [len(l) for l in w_bl_hit_list]\n",
    "    diff_lengths = set(lengths) \n",
    "    counter = {}\n",
    "    for l in lengths:\n",
    "        if counter.get(l):\n",
    "            counter[l] += 1\n",
    "        else:\n",
    "            counter[l] = 1\n",
    "    if max_t:\n",
    "        min_length = min(min(diff_lengths),max_t)\n",
    "        max_t = min_length\n",
    "    else:\n",
    "        min_length = min(diff_lengths)\n",
    "    w_bl_hit_list = [l[:min_length] for l in w_bl_hit_list]\n",
    "\n",
    "    # wl_hit_matrix = ~np.matrix(w_bl_hit_list)\n",
    "    wl_hit_matrix = (~torch.tensor(w_bl_hit_list, dtype=bool)).to(torch.float)\n",
    "    # wl_hit_matrix\n",
    "\n",
    "    n = wl_hit_matrix.shape[0]\n",
    "\n",
    "    if max_t:\n",
    "        t_values = torch.arange(0,max_t)\n",
    "        indices = torch.arange(0,max_t)\n",
    "    else:\n",
    "        t_values = torch.arange(0,wl_hit_matrix.shape[1])\n",
    "        indices = torch.arange(0,wl_hit_matrix.shape[1])\n",
    "    # print(t_values[:10])\n",
    "\n",
    "    avg_cumulative = list()\n",
    "    std_cumulative = list()\n",
    "    prc_25_cumulative = list()\n",
    "    prc_50_cumulative = list()\n",
    "    prc_75_cumulative = list()\n",
    "\n",
    "    prc_25_seq_indices = list()\n",
    "\n",
    "    for idx in indices:\n",
    "\n",
    "        hits_upto_t = wl_hit_matrix[:,:idx+1]\n",
    "        cumulative_sum_to_t = hits_upto_t.sum(axis=1)\n",
    "        wl_frac_at_t = cumulative_sum_to_t/(t_values[idx]+1)\n",
    "        \n",
    "        if z_scores:\n",
    "            wl_z_score_at_t = compute_z_score(wl_frac_at_t, t_values[idx], gamma)\n",
    "            avg_at_t = torch.mean(wl_z_score_at_t,axis=0)\n",
    "            std_at_t = torch.std(wl_z_score_at_t,axis=0)\n",
    "            prc_25_at_t = torch.quantile(wl_z_score_at_t,q=0.25,axis=0)\n",
    "            prc_50_at_t = torch.quantile(wl_z_score_at_t,q=0.50,axis=0)\n",
    "            prc_75_at_t = torch.quantile(wl_z_score_at_t,q=0.75,axis=0)\n",
    "\n",
    "            if gamma == 0.9: # and idx > 20 and idx < 90:\n",
    "                pcen=np.quantile(wl_z_score_at_t,0.75,interpolation='nearest')\n",
    "                i_near=abs(wl_z_score_at_t-pcen).argmin()\n",
    "                # prc_25_seq_indices.append((i_near.item(),pcen))\n",
    "                prc_25_seq_indices.append((i_near.item()))\n",
    "        else:\n",
    "            avg_at_t = torch.mean(wl_frac_at_t,axis=0)\n",
    "            std_at_t = torch.std(wl_frac_at_t,axis=0)\n",
    "            prc_25_at_t = torch.quantile(wl_frac_at_t,q=0.25,axis=0)\n",
    "            prc_50_at_t = torch.quantile(wl_frac_at_t,q=0.50,axis=0)\n",
    "            prc_75_at_t = torch.quantile(wl_frac_at_t,q=0.75,axis=0)\n",
    "\n",
    "        avg_cumulative.append(avg_at_t.item())\n",
    "        std_cumulative.append(std_at_t.item())\n",
    "        prc_25_cumulative.append(prc_25_at_t.item())\n",
    "        prc_50_cumulative.append(prc_50_at_t.item())\n",
    "        prc_75_cumulative.append(prc_75_at_t.item())\n",
    "\n",
    "\n",
    "    print(prc_25_seq_indices)\n",
    "\n",
    "    avg_cumulative = np.array(avg_cumulative)\n",
    "    std_cumulative = np.array(std_cumulative)\n",
    "    std_err_cumulative = std_cumulative/np.sqrt(n)\n",
    "    var_cumulative = std_cumulative**2\n",
    "    \n",
    "    plt.plot(t_values, avg_cumulative, color=cmap[grp_idx],  label=name)\n",
    "\n",
    "    # bounds stuff\n",
    "\n",
    "    # plt.plot(t_values, prc_25_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # # plt.plot(t_values, prc_50_cumulative, color=cmap[grp_idx], linestyle='--', label=name+',50th') \n",
    "    # plt.plot(t_values, prc_75_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',75th ') \n",
    "    # #fill between the upper and lower bands\n",
    "    # plt.fill_between(t_values, prc_25_cumulative, prc_75_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "    # or just lower\n",
    "    # plt.fill_between(t_values, prc_25_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "\n",
    "    # plt.plot(t_values, avg_cumulative-std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # plt.plot(t_values, avg_cumulative+std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # plt.plot(t_values, avg_cumulative-std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # plt.plot(t_values, avg_cumulative+std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # plt.plot(t_values, avg_cumulative-var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # plt.plot(t_values, avg_cumulative+var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
    "    # fill between the upper and lower bands\n",
    "    # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative+std_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "    # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative+std_err_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "    # or just lower\n",
    "    # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "    # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
    "\n",
    "# plt.plot([0.0],[0.0],label=f'25th Percentile', linestyle=\"dashed\", color=\"gray\")\n",
    "\n",
    "# if beam_search:\n",
    "#     plt.title(f\"Greedy, {beam_search}-way BS\")\n",
    "\n",
    "legend_font = 11\n",
    "\n",
    "# zoom_midrange = True\n",
    "# zoom = True\n",
    "\n",
    "zoom = False\n",
    "\n",
    "if zoom:\n",
    "    if z_scores:\n",
    "        plt.legend(loc = 'upper left', fontsize=legend_font)\n",
    "    else:\n",
    "        plt.legend(loc = 'lower right', fontsize=legend_font)\n",
    "    if zoom_midrange:\n",
    "        plt.xlim([(min_length)/4, (3*(max_t if max_t else min_length)/4)+1])\n",
    "    else:\n",
    "        plt.xlim([0, ((max_t if max_t else min_length)/4)+1])\n",
    "    plot_name = f\"z_vs_t_zoom_ablate_{ablate}\" if z_scores else f\"wl_vs_t_zoom_ablate_{ablate}\"\n",
    "else:\n",
    "    if z_scores:\n",
    "        plt.legend(loc = 'upper left', fontsize=legend_font)\n",
    "    else:\n",
    "        plt.legend(loc = 'lower right', fontsize=legend_font)\n",
    "  \n",
    "    plt.xlim([0, ((max_t if max_t else min_length))+1])\n",
    "\n",
    "    plot_name = f\"z_vs_t_ablate_{ablate}\" if z_scores else f\"wl_vs_t_ablate_{ablate}\"\n",
    "\n",
    "axes_label_fonts = 14\n",
    "if z_scores:\n",
    "    plt.ylabel('z-score',fontsize=axes_label_fonts)\n",
    "else:\n",
    "    plt.ylabel('Whitelist Fraction',fontsize=axes_label_fonts)\n",
    "plt.xlabel('T',fontsize=axes_label_fonts)\n",
    "\n",
    "# import matplotlib.ticker as ticker\n",
    "# tick_spacing = 5.0\n",
    "# plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
    "\n",
    "axes_tick_font = 13\n",
    "plt.xticks(fontsize=axes_tick_font)\n",
    "plt.yticks(fontsize=axes_tick_font)\n",
    "\n",
    "plt.grid()\n",
    "plt.tight_layout()\n",
    "\n",
    "if beam_search:\n",
    "    if ablate == \"gamma\":\n",
    "        plot_name = f\"greedy_{beam_search}_beams_delta_{delta}\" \n",
    "    if ablate == \"delta\":\n",
    "        plot_name = f\"greedy_{beam_search}_beams_gamma_{gamma}\" \n",
    "\n",
    "# plot_name = \"z_vs_t_ablate_gamma_boosted_delta\"\n",
    "# plot_name = \"z_vs_t_ablate_delta_boosted_gamma\"\n",
    "\n",
    "print(plot_name)\n",
    "\n",
    "\n",
    "if save_fig:\n",
    "    # fname = f\"figs/{plot_name}.pdf\"\n",
    "    fname = f\"figs_new/{plot_name}.pdf\"\n",
    "    plt.savefig(fname, format=\"pdf\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up data for charts (setup for figures 2&7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "viz_df = pd.DataFrame()\n",
    "\n",
    "# aggregating\n",
    "\n",
    "# set the hparam keys, including an indiv column for each you want to ablate on\n",
    "viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
    "for i,key in enumerate(groupby_fields):\n",
    "    viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
    "\n",
    "# viz_df[\"delta\"] = viz_df[\"bl_logit_bias\"].values\n",
    "viz_df[\"gamma\"] = viz_df[\"gamma\"].values\n",
    "# viz_df[\"gamma\"] = np.ones_like(viz_df[\"bl_proportion\"].values) - viz_df[\"bl_proportion\"].values\n",
    "\n",
    "# aggregate each field of interest for each hparam setting (group)\n",
    "describe_dict = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe()\n",
    "viz_df[\"w_bl_exp_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"w_bl_exp_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"w_bl_var_whitelist_fraction\"].describe()\n",
    "viz_df[\"w_bl_var_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"w_bl_var_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
    "viz_df[\"w_bl_whitelist_fraction_min\"] = describe_dict[\"min\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_25\"] = describe_dict[\"25%\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_50\"] = describe_dict[\"50%\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_75\"] = describe_dict[\"75%\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_max\"] = describe_dict[\"max\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
    "viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "\n",
    "describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
    "viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
    "viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"baseline_z_score\"].describe()\n",
    "viz_df[\"baseline_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"baseline_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "\n",
    "describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
    "viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
    "viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"baseline_ppl\"].describe()\n",
    "viz_df[\"baseline_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"baseline_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
    "viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
    "viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
    "\n",
    "print(f\"groupby legend: {groupby_fields}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filtering\n",
    "\n",
    "viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
    "\n",
    "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
    "\n",
    "\n",
    "# fix one of the bl params for analytic chart\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.9) & (viz_df[\"delta\"]<=10.0)]\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.75) & (viz_df[\"delta\"]<=10.0)]\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.25) & (viz_df[\"delta\"]<=10.0)]\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"]==0.1) & (viz_df[\"delta\"]<=10.0)]\n",
    "\n",
    "# for the sample pareto chart\n",
    "viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
    "# viz_df = viz_df[(viz_df[\"delta\"]<=2.0)] # zoom in on lower deltas\n",
    "# viz_df = viz_df[(viz_df[\"delta\"] >= 2.0) & (viz_df[\"delta\"]<=10.0)] # mid deltas\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.25) & (viz_df[\"gamma\"] != 0.75) & (viz_df[\"delta\"]<=2.0)]\n",
    "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.1) & (viz_df[\"gamma\"] != 0.9) & (viz_df[\"delta\"]<=2.0)]\n",
    "\n",
    "# viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
    "\n",
    "# viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
    "\n",
    "# for the beams pareto\n",
    "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
    "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
    "\n",
    "print(len(viz_df))\n",
    "\n",
    "viz_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# grouped_df[\"avg_spike_entropy\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# viz_df[[\"gamma\",\"avg_spike_entropy_mean\"]]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Exp vs Empirical WL fraction chart (figure 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# plt.style.use(\"classic\")\n",
    "plt.style.use(\"default\")\n",
    "# plt.style.use('ggplot') \n",
    "# plt.style.use('seaborn')\n",
    "\n",
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "rc('text', usetex=True)\n",
    "\n",
    "\n",
    "plt.clf()\n",
    "# plt.figure(figsize=(16, 4))\n",
    "# plt.figure(figsize=(8, 4))\n",
    "plt.figure(constrained_layout=True)\n",
    "plt.figure(figsize=(5, 4))\n",
    "\n",
    "\n",
    "# x_col = 'bl_hparams'\n",
    "# a = viz_df[x_col].apply(str)\n",
    "\n",
    "# x_col = 'bl_logit_bias'\n",
    "# x_col = 'bl_proportion'\n",
    "x_col = \"delta\"\n",
    "# x_col = \"gamma\"\n",
    "\n",
    "a = viz_df[x_col]\n",
    "print(f\"Num configurations: {len(a)}\")\n",
    "\n",
    "y_col = 'w_bl_whitelist_fraction_mean'\n",
    "y_col_err = 'w_bl_whitelist_fraction_std'\n",
    "\n",
    "viridis = plt.colormaps['viridis'].resampled(4)\n",
    "# cmap = viridis.colors[::-1]\n",
    "cmap = viridis.colors\n",
    "\n",
    "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_mean\"].values, color=cmap[1], marker='o', label='Mean') \n",
    "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_25\"].values, color=cmap[1], linestyle='-.', label='25th Percentile') \n",
    "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_75\"].values, color=cmap[1], linestyle='-.', label='75th Percentile') \n",
    "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_min\"].values, color=cmap[1], linestyle='-.', label='min') \n",
    "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_max\"].values, color=cmap[1], linestyle='-.', label='max') \n",
    "\n",
    "#fill between the upper and lower bands\n",
    "plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = cmap[1])\n",
    "# plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = 'darkorchid')\n",
    "# plt.fill_between(a, y1_low, y1_high, alpha = .1,color = 'goldenrod')\n",
    "\n",
    "\n",
    "y_col = 'w_bl_exp_whitelist_fraction_mean'\n",
    "# y_col_err = 'w_bl_var_whitelist_fraction_mean'\n",
    "# d = viz_df[x_col].apply(str)\n",
    "\n",
    "# sub_df = viz_df[viz_df[\"num_beams\"]==1]\n",
    "\n",
    "a = viz_df[x_col]\n",
    "e = viz_df[y_col].values\n",
    "# plt.plot(a, e, label=\"Predicted Lower Bound\", color=cmap[-1])\n",
    "plt.plot(a, e, label=\"Analytic Bound\", color=\"r\")\n",
    "# f = viz_df[y_col_err].values\n",
    "# # f = np.sqrt(viz_df[y_col_err].values)\n",
    "# plt.errorbar(d, e, yerr=f, fmt=\"o\")\n",
    "\n",
    "plt.legend(loc=\"lower right\",frameon=True, facecolor=\"white\")\n",
    "\n",
    "# for logit bias x axis\n",
    "# log_axis = True\n",
    "log_axis = False\n",
    "if log_axis:\n",
    "    plt.xscale(\"log\")\n",
    "\n",
    "ax = plt.gca()\n",
    "plt.draw()\n",
    "\n",
    "\n",
    "\n",
    "plt.xlabel(f\"Green List Bias, $\\delta$\")\n",
    "# plt.xlabel(f\"Whitelist size := $\\gamma$\")\n",
    "\n",
    "plt.ylabel(\"Fraction in Green List\")\n",
    "\n",
    "\n",
    "plt.grid()\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "if log_axis:\n",
    "    plot_name = \"analytic_w_sampling_log.pdf\"\n",
    "else:\n",
    "    plot_name = \"analytic_w_sampling_linear.pdf\"\n",
    "    # plot_name = f\"analytic_w_sampling_linear_gamma_{viz_df['gamma'].values[0]}.pdf\"\n",
    "\n",
    "# plot_name = \"analytic_w_sampling_linear_greenlist.pdf\"\n",
    "print(plot_name)\n",
    "\n",
    "# fname = f\"figs/{plot_name}\"\n",
    "# plt.savefig(fname, format=\"pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# delta gamma sampling pareto plot (figure 2 left)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "rc('text', usetex=True)\n",
    "\n",
    "plt.clf()\n",
    "plt.figure(constrained_layout=True)\n",
    "plt.figure(figsize=(5, 4))\n",
    "\n",
    "\n",
    "x_col = 'w_bl_ppl_mean'\n",
    "y_col = 'w_bl_z_score_mean'\n",
    "\n",
    "# markers = [\"x\", \"p\", \"*\", \"P\"]\n",
    "\n",
    "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
    "gammas = sorted(np.unique(viz_df[\"gamma\"].values), reverse=True)\n",
    "print(deltas, gammas)\n",
    "gamma_labels = [(g if g > 0.1 else 0.1) for g in gammas]\n",
    "\n",
    "markers = [\"x\", \"p\", \"*\", \"P\"][:len(deltas)]\n",
    "\n",
    "num_colors = len(gammas)\n",
    "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
    "# cmap = cmr.get_sub_cmap('plasma', 0.0, 0.66, N=num_colors)\n",
    "colors = cmap.colors#[::-1]\n",
    "\n",
    "\n",
    "for i,delta in enumerate(deltas):\n",
    "    for j,gamma in enumerate(gammas):\n",
    "        sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
    "        a = sub_df[x_col].values\n",
    "        b = sub_df[y_col].values\n",
    "        # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
    "        plt.plot(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
    "\n",
    "\n",
    "x_col = 'no_bl_ppl_mean'\n",
    "y_col = 'no_bl_z_score_mean'\n",
    "# x_col = 'baseline_ppl_mean'\n",
    "# y_col = 'baseline_z_score_mean'\n",
    "\n",
    "\n",
    "for i,delta in enumerate(deltas):\n",
    "    for j,gamma in enumerate(gammas):\n",
    "        sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
    "        a = sub_df[x_col].values\n",
    "        b = sub_df[y_col].values\n",
    "        plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j])\n",
    "\n",
    "# # # for manual legend\n",
    "plt.scatter([-1],[-1], label=\"Vanilla\", color=\"gray\", marker=\"o\")\n",
    "\n",
    "ax = plt.gca()\n",
    "\n",
    "from matplotlib.cm import ScalarMappable\n",
    "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
    "cmap = ListedColormap(colors)\n",
    "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
    "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(gammas))],shrink=0.6, pad = 0.03)\n",
    "cbar.ax.set_yticklabels(gamma_labels) \n",
    "cbar.set_label('$\\gamma$', rotation=0)\n",
    "\n",
    "\n",
    "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
    "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
    "# all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['baseline_ppl_mean'].values])\n",
    "# all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['baseline_z_score_mean'].values])\n",
    "\n",
    "min_x, max_x = np.min(all_x), np.max(all_x)\n",
    "min_y, max_y = np.min(all_y), np.max(all_y)\n",
    "\n",
    "# x_min_tick = 1.0\n",
    "x_min_tick = 3.0\n",
    "x_max_tick = np.ceil([max_x])[0]+1.0\n",
    "y_min_tick = 0.0\n",
    "y_max_tick = np.ceil([max_y])[0]+1.0\n",
    "\n",
    "x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
    "y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
    "\n",
    "\n",
    "x_lim_min = 3.0\n",
    "x_lim_max = x_max_tick\n",
    "y_lim_min = 0.45\n",
    "# y_lim_max = 1.09\n",
    "y_lim_max = 1.005\n",
    "\n",
    "\n",
    "# plt.xlim((x_min_tick-0.5,x_max_tick))\n",
    "plt.xlim((x_lim_min,x_lim_max))\n",
    "# plt.xlim((4.0,8.0))\n",
    "# plt.ylim((-1.0,20.0))\n",
    "# plt.ylim((y_lim_min,y_lim_max))\n",
    "\n",
    "ax.set_xticks(x_ticks)\n",
    "# ax.set_yticks(y_ticks)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "# # manual legend for dual parameter visualization\n",
    "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
    "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
    "handles += [f(\"o\", \"gray\")]\n",
    "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
    "plt.legend(handles, labels, loc=\"upper right\", framealpha=1)\n",
    "\n",
    "plt.grid()\n",
    "\n",
    "plt.xlabel(\"Oracle Model PPL (better →)\")\n",
    "plt.ylabel(\"z-score (better →)\")\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# plot_name = \"pareto_sampling_no_beams\"\n",
    "# fname = f\"figs/{plot_name}.pdf\"\n",
    "# plt.savefig(fname, format=\"pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# beams pareto plot (figure 2 right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_colors = 3\n",
    "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
    "colors = cmap.colors#[::-1]\n",
    "\n",
    "# plt.style.use('ggplot')\n",
    "# plt.style.use('seaborn')\n",
    "\n",
    "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "rc('text', usetex=True)\n",
    "\n",
    "plt.clf()\n",
    "plt.figure(constrained_layout=True)\n",
    "plt.figure(figsize=(5, 4))\n",
    "\n",
    "\n",
    "x_col = 'w_bl_ppl_mean'\n",
    "y_col = 'w_bl_z_score_mean'\n",
    "\n",
    "markers = [\"s\",\"D\", \"x\", \"p\",  \"*\", \"P\"] # <--- seems to match other pareto fig ordering\n",
    "\n",
    "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
    "num_beams = sorted(np.unique(viz_df[\"num_beams\"].values))\n",
    "# gamma_labels = [(g if g > 0.1 else 0.1) for g in np.unique(viz_df[\"gamma\"].values)]\n",
    "\n",
    "for i,n_beams in enumerate(num_beams):\n",
    "    for j,delta in enumerate(deltas):\n",
    "        sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
    "        a = sub_df[x_col].values\n",
    "        b = sub_df[y_col].values\n",
    "        # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
    "        plt.plot(a, b, label=f\"$\\delta={delta}$\", color=colors[i], marker=markers[j])\n",
    "\n",
    "\n",
    "x_col = 'no_bl_ppl_mean'\n",
    "y_col = 'no_bl_z_score_mean'\n",
    "\n",
    "\n",
    "\n",
    "for i,n_beams in enumerate(num_beams):\n",
    "    for j,delta in enumerate(deltas):\n",
    "        sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
    "        a = sub_df[x_col].values\n",
    "        b = sub_df[y_col].values\n",
    "        plt.scatter(a, b, label=f\"$\\delta={delta}$\", color=colors[i])\n",
    "\n",
    "# # # for manual legend\n",
    "plt.scatter([-10],[-10], label=\"$\\delta=0$\", color=\"gray\", marker=\"o\")\n",
    "\n",
    "ax = plt.gca()\n",
    "\n",
    "from matplotlib.cm import ScalarMappable\n",
    "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
    "cmap = ListedColormap(colors)\n",
    "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
    "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(num_beams))],shrink=0.6, pad = 0.04)\n",
    "# cbar.set_ticks(num_beams)\n",
    "cbar.set_ticklabels(num_beams)\n",
    "# cbar.ax.set_yticklabels(num_beams) \n",
    "cbar.set_label('Num Beams', rotation=90)\n",
    "\n",
    "\n",
    "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
    "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
    "\n",
    "min_x, max_x = np.min(all_x), np.max(all_x)\n",
    "min_y, max_y = np.min(all_y), np.max(all_y)\n",
    "\n",
    "# x_max_tick = np.ceil([max_x])[0]+1.0\n",
    "x_max_tick = np.ceil([max_x])[0]\n",
    "y_max_tick = np.ceil([max_y])[0]+1.0\n",
    "\n",
    "\n",
    "plt.xlim((1.0,x_max_tick))\n",
    "plt.ylim((-1.0,y_max_tick))\n",
    "\n",
    "# x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
    "# y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
    "\n",
    "# ax.set_xticks(x_ticks)\n",
    "# ax.set_yticks(y_ticks)\n",
    "\n",
    "ax.invert_xaxis()\n",
    "\n",
    "# # manual legend for dual parameter visualization\n",
    "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
    "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
    "handles += [f(\"o\", \"gray\")]\n",
    "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
    "plt.legend(handles, labels, loc=\"lower left\", framealpha=1)\n",
    "\n",
    "plt.grid()\n",
    "\n",
    "plt.xlabel(\"Oracle Model PPL (better →)\")\n",
    "plt.ylabel(\"z-score (better →)\")\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "\n",
    "plot_name = \"pareto_greedy_w_beams\"\n",
    "print(plot_name)\n",
    "\n",
    "# fname = f\"figs/{plot_name}.pdf\"\n",
    "# plt.savefig(fname, format=\"pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## z-score vs. entropy (not in paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The key tuple passed to get_group must list values in the same order as `groupby_fields`.\n",
    "print(f\"groupby legend: {groupby_fields}\")\n",
    "\n",
    "# Other values tried for the last key element: 0.1, 0.25, 0.75, 0.9\n",
    "selected_group = (True, 1, 2.0, 0.5)\n",
    "hist_subset = grouped_df.get_group(selected_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row count of the selected group.\n",
    "print(hist_subset.shape[0])\n",
    "\n",
    "# Disabled filter: keep rows whose space fraction is <= 0.9 in both the w_bl and no_bl outputs,\n",
    "# then re-print the row count.\n",
    "# hist_subset = hist_subset[(hist_subset[\"w_bl_space_frac\"] <= 0.9) & (hist_subset[\"no_bl_space_frac\"] <= 0.9)]\n",
    "# print(hist_subset.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column plotted against entropy; alternatives: \"w_bl_z_score\", \"no_bl_z_score\".\n",
    "z_col = \"baseline_z_score\"\n",
    "\n",
    "entropy = hist_subset[\"avg_spike_entropy\"]\n",
    "z_scores = hist_subset[z_col]\n",
    "\n",
    "plt.clf()\n",
    "plt.scatter(entropy, z_scores)\n",
    "plt.grid()\n",
    "plt.xlabel(\"Entropy\")\n",
    "plt.ylabel(\"z-score\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cols_to_tabulate = [\n",
    "    'idx',\n",
    "    'truncated_input',\n",
    "    'baseline_completion',\n",
    "    'no_bl_output',\n",
    "    'w_bl_output',\n",
    "    'avg_spike_entropy',\n",
    "    'no_bl_z_score',\n",
    "    'w_bl_z_score',\n",
    "    'w_bl_whitelist_fraction',\n",
    "    'no_bl_whitelist_fraction',\n",
    "    'baseline_ppl',\n",
    "    'no_bl_ppl',\n",
    "    'w_bl_ppl'\n",
    "]\n",
    "\n",
    "slice_size = 10\n",
    "\n",
    "# Sort once by entropy (ascending); the low / mid / high bands are all cut from this one ordering.\n",
    "entropy_sorted = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"], ascending=True)\n",
    "\n",
    "num_examples = len(entropy_sorted)\n",
    "# Window of exactly `slice_size` rows centred on the median entropy rank, so the \"mid\" band\n",
    "# really sits between the low (head) and high (tail) bands. Clamp at 0 so a small group\n",
    "# does not yield a negative start index (which iloc would interpret as counting from the end).\n",
    "midpt = num_examples // 2\n",
    "lower = max(midpt - slice_size // 2, 0)\n",
    "upper = lower + slice_size\n",
    "\n",
    "high_entropy_examples = entropy_sorted.tail(slice_size)\n",
    "mid_entropy_examples = entropy_sorted.iloc[lower:upper]\n",
    "low_entropy_examples = entropy_sorted.head(slice_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outlier inspection: rows with low avg_spike_entropy that still have a high baseline_z_score.\n",
    "low_entropy = hist_subset[\"avg_spike_entropy\"] < 0.7\n",
    "high_z = hist_subset[\"baseline_z_score\"] >= 7.0\n",
    "# Variants tried: \"w_bl_z_score\" with thresholds 12.0 / 14.0;\n",
    "# append .to_csv(\"input/pile_low_S_high_z_outliers.csv\") to save the table.\n",
    "hist_subset.loc[low_entropy & high_z, cols_to_tabulate]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The `slice_size` rows with the highest avg_spike_entropy (tail of the ascending sort in the slicing cell above).\n",
    "high_entropy_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows iloc[lower:upper] of the entropy-sorted table built in the slicing cell above.\n",
    "mid_entropy_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The `slice_size` rows with the lowest avg_spike_entropy (head of the ascending sort in the slicing cell above).\n",
    "low_entropy_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting histograms of the metrics for single runs (not in paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"groupby legend: {groupby_fields}\")\n",
    "\n",
    "# Key values must follow the groupby keys and their order.\n",
    "# Alternatives tried for the last element: 0.1, 0.25, 0.75, 0.9\n",
    "hist_key = (True, 1, 2.0, 0.5)\n",
    "hist_subset = grouped_df.get_group(hist_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Old filters (disabled) used to smooth the histograms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled filters. If re-enabled they would (1) keep only rows where both the no_bl and w_bl\n",
    "# generations produced exactly max_new_tokens tokens, and (2) drop rows with an empty truncated_input.\n",
    "# hist_subset = hist_subset[(hist_subset[\"no_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"]) & (hist_subset[\"w_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"])]\n",
    "# hist_subset = hist_subset[hist_subset[\"truncated_input\"] != \"\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sequence fraction of generated tokens that fall in the whitelist.\n",
    "# To plot z-scores instead, point these at the *_z_score columns.\n",
    "all_no_bl_wl_fractions = hist_subset[\"no_bl_whitelist_fraction\"]\n",
    "all_w_bl_wl_fractions = hist_subset[\"w_bl_whitelist_fraction\"]\n",
    "all_baseline_wl_fractions = hist_subset[\"baseline_whitelist_fraction\"]\n",
    "\n",
    "plt.clf()\n",
    "\n",
    "# Shared bin edges spanning all three series (no_bl included so the edges stay the same if its histogram is re-enabled).\n",
    "# nanmin/nanmax so a single NaN does not propagate into the bin edges.\n",
    "# n_bins + 1 edges give exactly n_bins bins.\n",
    "all_vals = np.concatenate([all_baseline_wl_fractions, all_w_bl_wl_fractions, all_no_bl_wl_fractions])\n",
    "n_bins = 50\n",
    "bins = np.linspace(np.nanmin(all_vals), np.nanmax(all_vals), n_bins + 1)\n",
    "\n",
    "# Disabled: plt.hist(all_no_bl_wl_fractions, bins=bins, alpha=0.6, label='no blacklisting')\n",
    "\n",
    "plt.hist(all_w_bl_wl_fractions,\n",
    "         bins=bins,\n",
    "         alpha=0.6,\n",
    "         label='with blacklisting')\n",
    "\n",
    "plt.hist(all_baseline_wl_fractions,\n",
    "         bins=bins,\n",
    "         alpha=0.4,\n",
    "         label='ground truth/real text')\n",
    "\n",
    "plt.legend(loc='upper right')\n",
    "plt.xlabel(\"fraction of total toks gen'd in WL\")\n",
    "plt.ylabel(\"freq\")\n",
    "plt.title(\"Output Whitelist Token Distribution\")\n",
    "\n",
    "# plot_name = \"wl_distro\"\n",
    "# plt.savefig(f\"figs/{plot_name}.png\", dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.clf()\n",
    "\n",
    "all_no_bl_ppls = hist_subset[\"no_bl_ppl\"]\n",
    "all_w_bl_ppls = hist_subset[\"w_bl_ppl\"]\n",
    "all_baseline_ppls = hist_subset[\"baseline_ppl\"]  # also used by the next cell\n",
    "\n",
    "# Bins run from the smallest generated PPL up to a fixed cap; values above `ppl_max` are not drawn.\n",
    "# np.nanmin is O(n) and NaN-safe, unlike taking element 0 of a sorted list (NaN breaks sort order).\n",
    "# n_bins + 1 edges give exactly n_bins bins.\n",
    "ppl_max = 20\n",
    "n_bins = 50\n",
    "ppl_min = np.nanmin(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n",
    "bins = np.linspace(ppl_min, ppl_max, n_bins + 1)\n",
    "\n",
    "plt.hist(all_no_bl_ppls,\n",
    "         bins=bins,\n",
    "         alpha=0.6,\n",
    "         label='no blacklisting')\n",
    "\n",
    "plt.hist(all_w_bl_ppls,\n",
    "         bins=bins,\n",
    "         alpha=0.6,\n",
    "         label='with blacklisting')\n",
    "\n",
    "plt.legend(loc='upper right')\n",
    "plt.xlabel(\"perplexity (lower is better)\")\n",
    "plt.ylabel(\"freq\")\n",
    "plt.title('Model-based Output Quality/Fluency')\n",
    "\n",
    "# plot_name = \"ppl_no_baseline\"\n",
    "# plt.savefig(f\"figs/{plot_name}.png\", dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.clf()\n",
    "\n",
    "# Bins run from the smallest PPL across all three series up to a fixed cap; values above `ppl_max` are not drawn.\n",
    "# The baseline PPLs are included in the minimum so real-text values below every generated PPL are not silently dropped.\n",
    "# np.nanmin is O(n) and NaN-safe; n_bins + 1 edges give exactly n_bins bins.\n",
    "ppl_max = 20\n",
    "n_bins = 50\n",
    "ppl_min = np.nanmin(np.concatenate([all_no_bl_ppls, all_w_bl_ppls, all_baseline_ppls]))\n",
    "bins = np.linspace(ppl_min, ppl_max, n_bins + 1)\n",
    "\n",
    "plt.hist(all_no_bl_ppls,\n",
    "         bins=bins,\n",
    "         alpha=0.6,\n",
    "         label='no blacklisting')\n",
    "\n",
    "plt.hist(all_w_bl_ppls,\n",
    "         bins=bins,\n",
    "         alpha=0.6,\n",
    "         label='with blacklisting')\n",
    "\n",
    "plt.hist(all_baseline_ppls,\n",
    "         bins=bins,\n",
    "         alpha=0.4,\n",
    "         label='ground truth/real text')\n",
    "\n",
    "plt.legend(loc='upper right')\n",
    "plt.xlabel(\"perplexity (lower is better)\")\n",
    "plt.ylabel(\"freq\")\n",
    "plt.title('Model-based Output Quality/Fluency')\n",
    "\n",
    "# plot_name = \"ppl_w_baseline\"\n",
    "# plt.savefig(f\"figs/{plot_name}.png\", dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
